import numpy as np; np.random.seed(0)
import seaborn as sns; sns.set()
import matplotlib.pyplot as plt
import torch

# values = np.random.rand(3, 3)
# x_ticks = ['x-1', 'x-2', 'x-3']
# y_ticks = ['y-1', 'y-2', 'y-3']  # custom x/y tick labels
# ax = sns.heatmap(values, xticklabels=x_ticks, yticklabels=y_ticks, annot=True)
# ax.set_title('Heatmap for test')  # plot title
# ax.set_xlabel('x label')  # x-axis label
# ax.set_ylabel('y label')
# plt.show()


# Demo input: a random 8x8 float tensor with entries drawn uniformly from [0, 1).
attention = torch.rand(size=(8, 8))

def plot_heat(attention, *, title='Heatmap for test', show=True):
    """Draw a 2-D matrix as an annotated seaborn heatmap.

    Row and column tick labels are the integer indices ('0', '1', ...).
    Both axes are titled 'Agent Index'. The colour scale is fixed to
    vmin=0, vmax=1, so separate plots share the same colour mapping.

    Args:
        attention: A 2-D ``torch.Tensor`` (on any device, with or without
            autograd history) or any array-like that ``np.asarray`` turns into
            a 2-D array.
        title: Axes title. Defaults to the previously hard-coded title.
        show: If True (the default), call ``plt.show()`` after drawing.

    Returns:
        The matplotlib Axes returned by ``sns.heatmap``.

    Raises:
        ValueError: If the input is not 2-D.
    """
    if isinstance(attention, torch.Tensor):
        # detach() drops autograd history. cpu() is required because seaborn
        # converts the data with numpy, which cannot read CUDA tensors.
        data = attention.detach().cpu().numpy()
    else:
        data = np.asarray(attention)
    if data.ndim != 2:
        raise ValueError(
            'plot_heat expects a 2-D matrix, got shape {}'.format(data.shape))

    n_rows, n_cols = data.shape
    # Label each axis from its own length so non-square matrices get the
    # right number of tick labels on each axis.
    x_ticks = [str(i) for i in range(n_cols)]
    y_ticks = [str(i) for i in range(n_rows)]

    ax = sns.heatmap(data, xticklabels=x_ticks, yticklabels=y_ticks,
                     annot=True, linewidths=0.5, vmin=0, vmax=1)
    ax.set_title(title)
    ax.set_xlabel('Agent Index')
    ax.set_ylabel('Agent Index')
    if show:
        plt.show()
    return ax

# plot_heat(attention)
